package com.diven.spark.ml.learn.feature

import org.apache.spark.sql.{DataFrame,SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.sql.types.DataTypes
import com.diven.spark.ml.learn.core.BaseTest
import com.diven.spark.ml.learn.core.BaseSpark
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.Interaction

object InteractionTest extends BaseTest {

    /** Factory method: builds a fresh [[InteractionTest]] exposed as a [[BaseSpark]]. */
    def apply(): BaseSpark = {
        new InteractionTest()
    }

}

/**
 * Demo of Spark ML's `Interaction` transformer: assembles two feature vectors
 * with `VectorAssembler`, then emits a single column containing the pairwise
 * (Cartesian) products of all input features.
 */
class InteractionTest extends BaseSpark {

    /**
     * Builds a small in-memory DataFrame with seven integer columns
     * ("id1" .. "id7") that serves as the demo input.
     *
     * @param sparkSession session used to create the frame; defaults to this test's session
     * @return a 6-row DataFrame with columns id1..id7
     */
    override def getDataFrame(sparkSession: SparkSession = this.getSparkSession()): DataFrame = {
         sparkSession.createDataFrame(Seq(
              (1, 1, 2, 3, 8, 4, 5),
              (2, 4, 3, 8, 7, 9, 8),
              (3, 6, 1, 9, 2, 3, 6),
              (4, 10, 8, 6, 9, 4, 5),
              (5, 9, 2, 7, 10, 7, 3),
              (6, 1, 1, 4, 2, 8, 4)
        )).toDF("id1", "id2", "id3", "id4", "id5", "id6", "id7")
    }

    /**
     * Runs the interaction pipeline: packs (id2,id3,id4) into "vec1" and
     * (id5,id6,id7) into "vec2", then crosses "id1" x "vec1" x "vec2" into
     * "interactedCol" (1*3*3 = 9 products per row) and prints the results.
     *
     * @param dataFrame input frame, expected to contain columns id1..id7
     */
    override def execute(dataFrame: DataFrame): Unit = {
        // Pre-processing: assemble the raw integer columns into two feature vectors.
        val assembler1 = new VectorAssembler()
            .setInputCols(Array("id2", "id3", "id4"))
            .setOutputCol("vec1")
        val assembler2 = new VectorAssembler()
            .setInputCols(Array("id5", "id6", "id7"))
            .setOutputCol("vec2")
        val assembled = assembler2.transform(assembler1.transform(dataFrame))

        // Cartesian product of the features: one scalar column and two 3-d vectors.
        val interaction = new Interaction()
            .setInputCols(Array("id1", "vec1", "vec2"))
            .setOutputCol("interactedCol")

        // Transform and display input, interacted output, and its schema.
        val interacted = interaction.transform(assembled)
        dataFrame.show()
        interacted.show(truncate = false)
        interacted.printSchema()
    }

}